"""
Phase 7: Referee-requested robustness for gravity paper.

Implements:
  A. Two-way clustered SEs (by reporter and partner) for Models 2b, 2c
  B. Pair FE + year FE OLS (within-pair identification)
  C. PPML with exporter×year and importer×year FE (structural gravity)
     - Demographics identified via ΔZ × time-invariant bilateral interactions
  D. Zero accounting (how many obs lost in log spec)
  E. Economic magnitudes (SD-based)

Output: gravity_bilateral/output/tables/referee_robustness.csv
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
from scipy import stats
import statsmodels.api as sm

sys.path.insert(0, str(Path("/mnt/c/demographics_capital_flows/gravity_bilateral")))
from src.model import PanelGLS

# Project root; every input/output location below is derived from it.
BASE_DIR = Path("/mnt/c/demographics_capital_flows/gravity_bilateral")
PROCESSED_DIR = BASE_DIR.joinpath("data", "processed")  # cleaned panel inputs
OUTPUT_DIR = BASE_DIR.joinpath("output", "tables")      # CSV result tables


def two_way_cluster_se(y, X, beta, reporters, partners):
    """
    Cameron-Gelbach-Miller two-way clustered SEs.

    V_twoway = V_reporter + V_partner - V_intersection
    where intersection clusters are reporter×partner (= pair).

    Each V_c = (X'X)^{-1} B_c (X'X)^{-1}
    B_c = sum over clusters of (X_c' e_c)(e_c' X_c), scaled by the
    small-sample factor G/(G-1) * (n-1)/(n-k), with G the number of
    clusters in that dimension.

    Parameters
    ----------
    y : array-like, shape (n,)
        Dependent variable.
    X : array-like, shape (n, k)
        Regressor matrix (including a constant column if one was estimated).
    beta : array-like, shape (k,)
        Coefficients used to form the residuals ``y - X @ beta``.
    reporters, partners : array-like, shape (n,)
        Cluster labels for the two clustering dimensions.

    Returns
    -------
    np.ndarray, shape (k,)
        Two-way clustered standard errors. Negative diagonal variances (the
        CGM subtraction is not guaranteed PSD) are truncated to zero.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    beta = np.asarray(beta, dtype=float)
    n, k = X.shape
    resid = y - X @ beta
    XtX_inv = np.linalg.inv(X.T @ X)
    # Per-observation score contributions x_i * e_i, shape (n, k).
    obs_scores = X * resid[:, None]

    def cluster_codes(labels):
        """Map arbitrary cluster labels to dense integer codes 0..G-1."""
        _, codes = np.unique(np.asarray(labels), return_inverse=True)
        return codes.ravel()

    def clustered_cov(codes):
        """Sandwich covariance (with small-sample correction) for one clustering."""
        G = int(codes.max()) + 1
        # Sum the scores within each cluster in a single vectorised pass.
        cluster_scores = np.zeros((G, k))
        np.add.at(cluster_scores, codes, obs_scores)
        B = cluster_scores.T @ cluster_scores
        correction = G / (G - 1) * (n - 1) / (n - k)
        return XtX_inv @ (B * correction) @ XtX_inv

    rep_codes = cluster_codes(reporters)
    par_codes = cluster_codes(partners)
    # Intersection clusters built from integer codes: collision-free, unlike
    # joining labels with "_" (("A_B", "C") and ("A", "B_C") would merge).
    pair_codes = cluster_codes(rep_codes * (int(par_codes.max()) + 1) + par_codes)

    V_twoway = (clustered_cov(rep_codes)
                + clustered_cov(par_codes)
                - clustered_cov(pair_codes))

    # Ensure non-negative diagonal (numerical issues / non-PSD CGM estimate)
    return np.sqrt(np.maximum(np.diag(V_twoway), 0))


def pair_fe_ols(df, dep_var, regressors, model_name):
    """
    OLS with pair fixed effects + year fixed effects.

    Pair FE are removed by the within transformation: the dependent
    variable, the regressors and the year dummies are all demeaned by
    ``pair_id``. The demeaned year dummies then play the role of year FE.
    No constant is estimated because it is absorbed by the pair FE.

    Standard errors are classical (homoskedastic, not clustered). They use
    residual degrees of freedom N - K - G, where G is the number of pairs,
    because each absorbed pair mean costs one degree of freedom.

    Parameters
    ----------
    df : pd.DataFrame
        Panel with ``dep_var``, ``regressors``, ``pair_id`` and ``year``.
        Year dummies are picked up when present under the names ``yr_<year>``.
    dep_var : str
        Dependent-variable column.
    regressors : list of str
        Time-varying regressors whose coefficients are reported.
    model_name : str
        Label used in the printed table and in the ``model`` column.

    Returns
    -------
    pd.DataFrame or None
        One row per regressor plus ``_R_squared_within``, ``_N_obs`` and
        ``_N_pairs`` rows. Returns None when fewer than 100 complete
        observations exist or when no residual degrees of freedom remain.
    """
    years = sorted(df['year'].dropna().unique())
    # First year is the omitted base category.
    yr_cols = [f'yr_{int(y)}' for y in years[1:]]
    yr_cols = [c for c in yr_cols if c in df.columns]

    all_vars = regressors + yr_cols
    est = df.dropna(subset=[dep_var] + all_vars + ['pair_id', 'year']).copy()

    if len(est) < 100:
        print(f"  {model_name}: insufficient obs ({len(est)})")
        return None

    # Within transformation (demean by pair) on a float frame. Writing the
    # demeaned 0/1 year dummies back into their int columns would truncate
    # every value to 0 and silently drop the year FE.
    dm_cols = [dep_var] + all_vars
    demeaned = est[dm_cols] - est.groupby('pair_id')[dm_cols].transform('mean')
    y_dm = demeaned[dep_var].to_numpy(dtype=float)
    X_full = demeaned[all_vars].to_numpy(dtype=float)  # regressors, then year FE

    # OLS on demeaned data (no constant — absorbed by FE)
    result = sm.OLS(y_dm, X_full).fit()

    n_obs = len(est)
    n_pairs = est['pair_id'].nunique()
    # statsmodels' df_resid (N - rank) ignores the absorbed pair means; the
    # within estimator needs N - K - G.
    dof = result.df_resid - n_pairs
    if dof <= 0:
        print(f"  {model_name}: no residual degrees of freedom after absorbing {n_pairs:,} pair FE")
        return None
    se = result.bse * np.sqrt(result.df_resid / dof)
    t_vals = result.params / se
    p_vals = 2 * stats.t.sf(np.abs(t_vals), dof)

    print(f"\n{'=' * 70}")
    print(f"  {model_name}")
    print(f"  N = {n_obs:,}, Pairs = {n_pairs:,}")
    print(f"  R² (within) = {result.rsquared:.4f}")
    print(f"{'=' * 70}")
    print(f"  {'Variable':<30} {'Coef':>10} {'SE':>10} {'p-val':>8}")

    results_list = []
    for i, v in enumerate(regressors):
        p = p_vals[i]
        sig = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
        print(f"  {v:<30} {result.params[i]:>10.4f} {se[i]:>10.4f} {p:>8.4f} {sig}")
        results_list.append({
            'model': model_name,
            'variable': v,
            'coefficient': result.params[i],
            'std_error': se[i],
            't_stat': t_vals[i],
            'p_value': p,
        })

    results_list.append({'model': model_name, 'variable': '_R_squared_within',
                         'coefficient': result.rsquared, 'std_error': np.nan,
                         't_stat': np.nan, 'p_value': np.nan})
    results_list.append({'model': model_name, 'variable': '_N_obs',
                         'coefficient': n_obs, 'std_error': np.nan,
                         't_stat': np.nan, 'p_value': np.nan})
    results_list.append({'model': model_name, 'variable': '_N_pairs',
                         'coefficient': n_pairs, 'std_error': np.nan,
                         't_stat': np.nan, 'p_value': np.nan})

    return pd.DataFrame(results_list)


def ppml_structural_gravity(df, model_name):
    """
    PPML with exporter×year and importer×year FE.

    Since country×year FE absorb level ΔZ, we identify through:
    - ΔZ × log(distance): does demographic distance matter more for distant pairs?
    - ΔZ × contiguity: does proximity amplify demographic flows?
    - ΔZ × common_language: does information reduce demographic barriers?

    Sampling, for computational feasibility:
    - It draws roughly a 10% random subsample of pairs: at least 500, and
      never more than the number of pairs available.
    - If that still leaves more than 2000 FE dummies, it draws a further
      roughly 1/3 subsample (at least 300, capped the same way).
    - Draws come from a local ``RandomState(42)``. They match the previous
      ``np.random.seed(42)`` behaviour without mutating NumPy's global RNG.

    Standard errors are clustered by ``pair_id``. PPML is used as a
    pseudo-likelihood estimator, so the Poisson model-based covariance
    (Var = mean) is not a valid basis for inference.

    Returns
    -------
    pd.DataFrame or None
        Rows for the interaction terms plus ``_N_obs`` / ``_N_pairs``.
        Returns None when too few observations survive the FE filter or
        when the fit raises.
    """
    from statsmodels.genmod.generalized_linear_model import GLM
    from statsmodels.genmod.families import Poisson

    # pair_id is required both for subsampling and as the cluster variable.
    est = df.dropna(subset=['portfolio_total', 'dZ_1', 'dZ_2', 'dZ_3',
                            'log_dist', 'contiguity', 'common_lang_official',
                            'reporter', 'partner', 'year', 'pair_id']).copy()

    # PPML keeps zero flows; negative positions are clipped to zero.
    est['flow'] = est['portfolio_total'].clip(lower=0) / 1e6  # Scale to millions

    # Create interaction terms
    est['dZ1_x_dist'] = est['dZ_1'] * est['log_dist']
    est['dZ2_x_dist'] = est['dZ_2'] * est['log_dist']
    est['dZ3_x_dist'] = est['dZ_3'] * est['log_dist']
    est['dZ1_x_contig'] = est['dZ_1'] * est['contiguity']
    est['dZ1_x_lang'] = est['dZ_1'] * est['common_lang_official']

    rng = np.random.RandomState(42)

    # Subsample (10% for speed with high-dimensional FE). Cap at the number
    # of pairs: choice(replace=False) raises if asked for more than exist.
    pairs = est['pair_id'].unique()
    n_draw = min(len(pairs), max(len(pairs) // 10, 500))
    sample_pairs = rng.choice(pairs, size=n_draw, replace=False)
    est = est[est['pair_id'].isin(sample_pairs)].copy()

    print(f"\n{'=' * 70}")
    print(f"  {model_name}")
    print(f"  N = {len(est):,}, Pairs = {est['pair_id'].nunique():,}")

    # Create exporter×year and importer×year labels (astype(str) so numeric
    # country codes concatenate too)
    year_str = est['year'].astype(int).astype(str)
    est['exp_yr'] = est['reporter'].astype(str) + '_' + year_str
    est['imp_yr'] = est['partner'].astype(str) + '_' + year_str

    # Keep FE with at least 5 observations (single pass, not iterated)
    exp_yr_counts = est['exp_yr'].value_counts()
    imp_yr_counts = est['imp_yr'].value_counts()
    keep_exp = exp_yr_counts[exp_yr_counts >= 5].index
    keep_imp = imp_yr_counts[imp_yr_counts >= 5].index
    est = est[est['exp_yr'].isin(keep_exp) & est['imp_yr'].isin(keep_imp)].copy()

    if len(est) < 500:
        print(f"  Insufficient obs after FE filter ({len(est)})")
        return None

    # Create FE dummies — use pandas get_dummies with drop_first
    exp_dummies = pd.get_dummies(est['exp_yr'], prefix='exp', drop_first=True, dtype=float)
    imp_dummies = pd.get_dummies(est['imp_yr'], prefix='imp', drop_first=True, dtype=float)

    # Check feasibility
    n_fe = exp_dummies.shape[1] + imp_dummies.shape[1]
    print(f"  Exporter×year FE: {exp_dummies.shape[1]}, Importer×year FE: {imp_dummies.shape[1]}")

    if n_fe > 2000:
        print(f"  Too many FE ({n_fe}) — further subsampling")
        pairs2 = est['pair_id'].unique()
        n_draw2 = min(len(pairs2), max(len(pairs2) // 3, 300))
        sample_pairs2 = rng.choice(pairs2, size=n_draw2, replace=False)
        est = est[est['pair_id'].isin(sample_pairs2)].copy()
        exp_dummies = pd.get_dummies(est['exp_yr'], prefix='exp', drop_first=True, dtype=float)
        imp_dummies = pd.get_dummies(est['imp_yr'], prefix='imp', drop_first=True, dtype=float)
        n_fe = exp_dummies.shape[1] + imp_dummies.shape[1]
        print(f"  After re-subsample: N={len(est):,}, FE={n_fe}")

    # Regressors: interactions only (levels absorbed by FE)
    interaction_vars = ['dZ1_x_dist', 'dZ2_x_dist', 'dZ3_x_dist',
                        'dZ1_x_contig', 'dZ1_x_lang']

    X_vars = est[interaction_vars].values
    X_fe = np.column_stack([X_vars, exp_dummies.values, imp_dummies.values])

    # Add constant (column 0), so interaction i sits at column i + 1
    X_full = sm.add_constant(X_fe)
    y = est['flow'].values
    groups = pd.factorize(est['pair_id'])[0]

    print(f"  Final: N={len(est):,}, regressors={X_full.shape[1]}")
    print("  Running PPML (this may take a minute)...")

    try:
        poisson = GLM(y, X_full, family=Poisson())
        result = poisson.fit(maxiter=100, method='IRLS',
                             cov_type='cluster', cov_kwds={'groups': groups})

        if getattr(result, 'converged', True):
            print(f"  Converged. Deviance={result.deviance:.1f}")
        else:
            print(f"  WARNING: IRLS did not converge in 100 iterations. Deviance={result.deviance:.1f}")
        print(f"  {'Variable':<25} {'Coef':>10} {'SE':>10} {'p-val':>8}")

        results_list = []
        for i, v in enumerate(interaction_vars):
            idx = i + 1  # +1 for constant
            p = result.pvalues[idx]
            sig = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
            print(f"  {v:<25} {result.params[idx]:>10.4f} {result.bse[idx]:>10.4f} {p:>8.4f} {sig}")
            results_list.append({
                'model': model_name,
                'variable': v,
                'coefficient': result.params[idx],
                'std_error': result.bse[idx],
                't_stat': result.tvalues[idx],
                'p_value': p,
            })

        results_list.append({'model': model_name, 'variable': '_N_obs',
                             'coefficient': len(est), 'std_error': np.nan,
                             't_stat': np.nan, 'p_value': np.nan})
        results_list.append({'model': model_name, 'variable': '_N_pairs',
                             'coefficient': est['pair_id'].nunique(), 'std_error': np.nan,
                             't_stat': np.nan, 'p_value': np.nan})

        return pd.DataFrame(results_list)

    except Exception as e:
        # Best-effort robustness check: report and let the other sections run.
        print(f"  PPML failed: {e}")
        return None


def _stars(p):
    """Return significance stars for p-value ``p`` ('' when p >= 0.1 or NaN)."""
    if p < 0.01:
        return '***'
    if p < 0.05:
        return '**'
    if p < 0.1:
        return '*'
    return ''


def _two_way_clustered_models(df, dep_var, model_specs, yr_cols):
    """
    Section A: compare GLS SEs with two-way and reporter-clustered OLS SEs.

    For each ``(model_name, regressors)`` in ``model_specs``:
    - PanelGLS is fit on regressors + ``yr_cols`` to obtain the comparison
      ("GLS") standard errors;
    - pooled OLS with a constant but *without* year dummies supplies the
      coefficients and the reporter-, partner- and pair-clustered SEs.
      These are combined on the diagonal as
      se_2way² = se_rep² + se_par² - se_pair², truncated at 0.

    The "Ratio" column therefore compares SEs across two different
    specifications (GLS with year dummies vs OLS without).

    Returns a list with one results DataFrame per model that ran.
    """
    out = []
    for model_name, regressors in model_specs:
        all_vars = regressors + yr_cols
        est = df.dropna(subset=[dep_var] + all_vars + ['pair_id', 'year', 'reporter', 'partner']).copy()
        if len(est) < 100:
            continue

        # Run pooled GLS for comparison SEs
        gls = PanelGLS()
        gls.fit(est[dep_var].values, est[all_vars].values,
                est['pair_id'].values, est['year'].values.astype(int))

        X_const = sm.add_constant(est[regressors].values)  # no year dummies — fewer cols, avoid singularity
        y_vals = est[dep_var].values
        t_dof = len(est) - len(regressors) - 1

        try:
            # Cluster-robust SEs for each one-way clustering (index 0 = constant)
            bse = {}
            for col in ('reporter', 'partner', 'pair_id'):
                groups = pd.Categorical(est[col]).codes
                bse[col] = sm.OLS(y_vals, X_const).fit(
                    cov_type='cluster', cov_kwds={'groups': groups}).bse

            # Cameron-Gelbach-Miller on the diagonal: V_rep + V_par - V_pair
            se_twoway = np.sqrt(np.maximum(
                bse['reporter'][1:]**2 + bse['partner'][1:]**2 - bse['pair_id'][1:]**2, 0))

            ols_base = sm.OLS(y_vals, X_const).fit()

            print(f"\n  {model_name}")
            print(f"  N = {len(est):,}")
            print(f"  {'Variable':<30} {'OLS coef':>10} {'GLS SE':>10} {'2way SE':>10} {'2way p':>8} {'Ratio':>7}")

            rows = []
            for i, v in enumerate(regressors):
                coef = ols_base.params[i + 1]
                # NOTE(review): assumes gls.se is ordered like the columns passed
                # to PanelGLS.fit (no prepended constant) — confirm in src.model.
                se_gls = gls.se[i]
                se_cl = se_twoway[i]
                t_cl = coef / se_cl if se_cl > 0 else np.nan
                p_cl = 2 * (1 - stats.t.cdf(abs(t_cl), t_dof)) if not np.isnan(t_cl) else np.nan
                ratio = se_cl / se_gls if se_gls > 0 else np.nan

                print(f"  {v:<30} {coef:>10.4f} {se_gls:>10.4f} {se_cl:>10.4f} {p_cl:>8.4f} {ratio:>6.2f}x {_stars(p_cl)}")
                rows.append({
                    'model': model_name,
                    'variable': v,
                    'coefficient': coef,
                    'std_error_gls': se_gls,
                    'std_error': se_cl,
                    't_stat': t_cl,
                    'p_value': p_cl,
                })

            # Also report reporter-only clustered for comparison
            print("\n  Reporter-only clustered SEs:")
            for i, v in enumerate(regressors):
                coef = ols_base.params[i + 1]
                se_rep = bse['reporter'][i + 1]
                t_rep = coef / se_rep if se_rep > 0 else np.nan
                p_rep = 2 * (1 - stats.t.cdf(abs(t_rep), t_dof)) if not np.isnan(t_rep) else np.nan
                print(f"  {v:<30} {se_rep:>10.4f} {p_rep:>8.4f} {_stars(p_rep)}")
                rows.append({
                    'model': model_name + ' (reporter only)',
                    'variable': v,
                    'coefficient': coef,
                    'std_error': se_rep,
                    't_stat': t_rep,
                    'p_value': p_rep,
                })

            rows.append({'model': model_name, 'variable': '_N_obs',
                         'coefficient': len(est), 'std_error': np.nan,
                         't_stat': np.nan, 'p_value': np.nan})
            out.append(pd.DataFrame(rows))

        except Exception as e:
            # Best-effort: a failed clustering should not abort later sections.
            print(f"  {model_name}: clustering failed: {e}")
    return out


def _zero_accounting(df):
    """Section D: print, per flow type, how many observations the log spec loses."""
    for flow_type in ['portfolio_total', 'portfolio_equity', 'portfolio_debt', 'fdi_outward']:
        if flow_type not in df.columns:
            continue
        flows = df[flow_type]
        total = flows.notna().sum()
        positive = (flows > 0).sum()
        zeros_explicit = (flows == 0).sum()
        log_col = f'log_{flow_type}'
        log_avail = df[log_col].notna().sum() if log_col in df.columns else 0

        # Guard both shares against an all-missing column.
        pct_lost = (1 - log_avail / total) * 100 if total > 0 else 0
        pct_zero = zeros_explicit / total * 100 if total > 0 else 0
        print(f"  {flow_type}:")
        print(f"    Total non-null: {total:,}")
        print(f"    Positive: {positive:,}")
        print(f"    Zero: {zeros_explicit:,} ({pct_zero:.1f}%)")
        print(f"    Log-transformed available: {log_avail:,}")
        print(f"    Lost in log: {total - log_avail:,} ({pct_lost:.1f}%)")


def _economic_magnitudes(df, dep_var, gravity_vars, demo_vars):
    """Section E: translate Model 2b ΔZ coefficients into % effects (SD and IQR)."""
    est = df.dropna(subset=[dep_var] + gravity_vars + demo_vars).copy()

    for v in demo_vars:
        sd = est[v].std()
        p25 = est[v].quantile(0.25)
        p75 = est[v].quantile(0.75)
        iqr = p75 - p25
        print(f"  {v}: SD={sd:.3f}, IQR={iqr:.3f} (p25={p25:.3f}, p75={p75:.3f})")

    # Model 2b point estimates.
    # NOTE(review): hard-coded from an earlier estimation run, not computed in
    # this script — keep in sync with the reported Model 2b table.
    coefs_2b = {'dZ_1': 3.675, 'dZ_2': -0.489, 'dZ_3': 0.019}

    print("\n  Model 2b magnitudes (1 SD increase in ΔZ):")
    for v, coef in coefs_2b.items():
        sd = est[v].std()
        effect_log = coef * sd
        effect_pct = (np.exp(effect_log) - 1) * 100
        print(f"    {v}: {coef:.3f} × {sd:.3f} = {effect_log:.3f} log points = {effect_pct:+.1f}% bilateral holdings")

    iqrs = {v: est[v].quantile(0.75) - est[v].quantile(0.25) for v in coefs_2b}

    print("\n  Model 2b magnitudes (25th to 75th percentile of ΔZ):")
    for v, coef in coefs_2b.items():
        effect_log = coef * iqrs[v]
        effect_pct = (np.exp(effect_log) - 1) * 100
        print(f"    {v}: {coef:.3f} × {iqrs[v]:.3f} = {effect_log:.3f} log points = {effect_pct:+.1f}% bilateral holdings")

    # Combined Z effect
    print("\n  Combined demographic distance effect (all three ΔZ, 25th→75th):")
    total_effect = sum(coef * iqrs[v] for v, coef in coefs_2b.items())
    total_pct = (np.exp(total_effect) - 1) * 100
    print(f"    Total: {total_effect:.3f} log points = {total_pct:+.1f}% bilateral holdings")


def main():
    """Run robustness sections A-E on the bilateral panel and save the results CSV."""
    print("=" * 70)
    print("PHASE 7: REFEREE-REQUESTED ROBUSTNESS")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "bilateral_panel.csv")
    print(f"Loaded: {len(df):,} obs")

    # Year dummies (first year is the omitted base category)
    years = sorted(df['year'].dropna().unique())
    for y in years[1:]:
        df[f'yr_{int(y)}'] = (df['year'] == y).astype(int)
    yr_cols = [f'yr_{int(y)}' for y in years[1:]]

    dep_var = 'log_portfolio_total'
    gravity_vars = ['log_dist', 'contiguity', 'common_lang_official', 'colonial_ties', 'log_gdp_product']
    demo_vars = ['dZ_1', 'dZ_2', 'dZ_3']
    kaopen_vars = ['kaopen_j', 'dZ_1_x_kaopen_j', 'dZ_2_x_kaopen_j', 'dZ_3_x_kaopen_j']

    all_results = []

    # ===================================================================
    # A. TWO-WAY CLUSTERED SEs
    # ===================================================================
    print("\n" + "=" * 70)
    print("A. TWO-WAY CLUSTERED STANDARD ERRORS")
    print("=" * 70)

    all_results.extend(_two_way_clustered_models(
        df, dep_var,
        [("2b: Two-way clustered", gravity_vars + demo_vars),
         ("2c: Two-way clustered", gravity_vars + demo_vars + kaopen_vars)],
        yr_cols))

    # ===================================================================
    # B. PAIR FE + YEAR FE
    # ===================================================================
    print("\n" + "=" * 70)
    print("B. PAIR FIXED EFFECTS + YEAR FE")
    print("=" * 70)

    # Only time-varying regressors survive pair FE:
    # log_dist, contiguity, common_lang, colonial_ties are time-invariant → absorbed.
    time_varying = ['log_gdp_product'] + demo_vars

    res_fe = pair_fe_ols(df, dep_var, time_varying, "2b: Pair FE + Year FE")
    if res_fe is not None:
        all_results.append(res_fe)

    time_varying_c = ['log_gdp_product'] + demo_vars + kaopen_vars
    res_fe_c = pair_fe_ols(df, dep_var, time_varying_c, "2c: Pair FE + Year FE")
    if res_fe_c is not None:
        all_results.append(res_fe_c)

    # ===================================================================
    # C. PPML WITH EXPORTER×YEAR AND IMPORTER×YEAR FE
    # ===================================================================
    print("\n" + "=" * 70)
    print("C. STRUCTURAL GRAVITY PPML (Country×Year FE)")
    print("=" * 70)

    res_ppml = ppml_structural_gravity(df, "PPML: Structural gravity (ΔZ × bilateral)")
    if res_ppml is not None:
        all_results.append(res_ppml)

    # ===================================================================
    # D. ZERO ACCOUNTING
    # ===================================================================
    print("\n" + "=" * 70)
    print("D. ZERO ACCOUNTING")
    print("=" * 70)

    _zero_accounting(df)

    # ===================================================================
    # E. ECONOMIC MAGNITUDES
    # ===================================================================
    print("\n" + "=" * 70)
    print("E. ECONOMIC MAGNITUDES")
    print("=" * 70)

    _economic_magnitudes(df, dep_var, gravity_vars, demo_vars)

    # ===================================================================
    # SAVE ALL RESULTS
    # ===================================================================
    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        # The output directory may not exist yet on a fresh checkout.
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        outfile = OUTPUT_DIR / "referee_robustness.csv"
        results_df.to_csv(outfile, index=False)
        print(f"\n  Saved: {outfile}")

    print("\n" + "=" * 70)
    print("PHASE 7 COMPLETE")
    print("=" * 70)


if __name__ == "__main__":
    main()
